<!-- Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================-->

<html>

<head>
  <title>TensorFlow.js Model Benchmark</title>
  <link href="https://fonts.googleapis.com/css?family=Roboto" rel="stylesheet">
  <link href="./main.css" rel="stylesheet">
  <script src="https://cdn.jsdelivr.net/npm/lil-gui@0.18"></script>
  <script src="loader.js"></script>
  <link rel="shortcut icon"
        href="https://www.gstatic.com/devrel-devsite/prod/vf4ca28c48392b1412e7b030290622a0dd55b62dec1202c59f119b1e23227c988/tensorflow/images/favicon.png">
</head>

<body>
  <h2>TensorFlow.js Model Benchmark</h2>
  <div id="modal-msg"></div>
  <div id="container">
    <div id="stats">
      <div class="box">
        <pre id="env"></pre>
      </div>
      <table class="table" id="parameters">
        <caption>Parameters</caption>
        <thead>
          <tr>
            <th>Type</th>
            <th>Value</th>
          </tr>
        </thead>
        <tbody>
        </tbody>
      </table>
      <table class="table" id="modelInputs" style="display: none;">
        <caption>Inputs</caption>
        <thead>
          <tr>
            <th>Type</th>
            <th>Value</th>
          </tr>
        </thead>
        <tbody></tbody>
      </table>
      <table class="table" id="timings">
        <caption>Model</caption>
        <thead>
          <tr>
            <th>Type</th>
            <th>Value</th>
          </tr>
        </thead>
        <tbody>
        </tbody>
      </table>
      <div class="box" id="perf-trendline-container">
        <div class="label">Inference times</div>
        <div class="trendline">
          <div class="yMax"></div>
          <div class="yMin"></div>
          <svg>
            <path></path>
          </svg>
        </div>
      </div>
    </div>
    <table class="table" id="kernels">
      <caption>Kernel</caption>
      <thead id="kernels-thead">
        <tr>
          <th>Type</th>
          <th>Time(ms)</th>
        </tr>
      </thead>
      <tbody></tbody>
    </table>
  </div>
  <script>
    'use strict';

    // Parsed query parameters of the page URL (a URLSearchParams), or null
    // when the URL carries no parameters. Reset to null at the end of a
    // successful runBenchmark().
    let urlState = null;
    // Defaults that getURLState() overrides from the `run`, `profile` and
    // `warmup` URL parameters.
    let runTimes = 50;
    let profileTimes = 1;
    let warmupTimes = 1;
    // Overridden by the `parallelCompile` URL parameter.
    // NOTE(review): the identifier is misspelled ("paralllel") but is
    // referenced elsewhere in this file, so it is kept as is.
    let paralllelCompile = false;
    // Task requested through the `task` URL parameter; by default no task is
    // run automatically.
    let task = '';
    /**
     * Parses a URL query string and copies the benchmark settings it carries
     * into the file-level variables `runTimes`, `profileTimes`, `warmupTimes`,
     * `task` and `paralllelCompile`.
     *
     * @param {string} url Query string to parse, e.g. `location.search`.
     * @returns {?URLSearchParams} The parsed parameters, or null when the
     *     query string contains no parameters at all.
     */
    function getURLState(url) {
      const params = new URLSearchParams(url);
      if ([...params.keys()].length === 0) {
        return null;
      }
      if (params.has('run')) {
        runTimes = Number(params.get('run'));
      }
      if (params.has('profile')) {
        profileTimes = Number(params.get('profile'));
      }
      if (params.has('warmup')) {
        warmupTimes = Number(params.get('warmup'));
      }
      if (params.has('task')) {
        task = params.get('task');
      }
      if (params.has('parallelCompile')) {
        // Boolean('false') is true because every non-empty string is truthy,
        // so the textual value has to be interpreted explicitly. An empty
        // value, 'false' and '0' disable the option; anything else enables it.
        const raw = params.get('parallelCompile').trim().toLowerCase();
        paralllelCompile = !['', 'false', '0'].includes(raw);
      }
      return params;
    }
    // GUI controller handles. They start as null and are filled in while the
    // GUI is built; updateGUIFromURLState() drives them from URL parameters,
    // and the user can also change them by hand.
    let modelController = null;
    let numWarmupsController = null;
    let numRunsController = null;
    let numProfilesController = null;
    let parallelCompileController = null;
    // Folder that hosts the per-model controllers (input size, input type,
    // architecture), which are re-created whenever the model changes.
    let modelParameterFolder = null;
    // Buttons that updateGUIFromURLState() clicks when the `task` URL
    // parameter asks for an automatic run.
    let runButton = null;
    let correctnessButton = null;
    let inputSizeController = null;
    let inputTypeController = null;
    let modelArchitectureController = null;
    // Flag name -> controller lookup used by updateGUIFromURLState().
    let tunableFlagsControllers = null;

    /**
     * Applies the settings found in `urlState` to the GUI controllers that
     * already exist, then starts the task requested by the `task` URL
     * parameter (if any) by clicking the matching button.
     *
     * Does nothing when `urlState` is null. Controllers that have not been
     * created yet (still null) are skipped.
     */
    function updateGUIFromURLState() {
      if (urlState == null) {
        return;
      }
      if (tunableFlagsControllers != null) {
        const booleanFlags = ['WEBGL_USE_SHAPES_UNIFORMS',
                              'KEEP_INTERMEDIATE_TENSORS',
                              'CHECK_COMPUTATION_FOR_ERRORS'];
        for (const flag of booleanFlags) {
          const controller = tunableFlagsControllers[flag];
          if (controller != null && urlState.has(flag)) {
            // Only the literal string 'true' turns a flag on.
            controller.setValue(urlState.get(flag) === 'true');
          }
        }

        const batchSizeFlag = 'WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE';
        const batchSizeController = tunableFlagsControllers[batchSizeFlag];
        if (batchSizeController != null && urlState.has(batchSizeFlag)) {
          batchSizeController.setValue(Number(urlState.get(batchSizeFlag)));
        }
      }
      // The model parameters will be updated automatically by onChange.
      if (modelController != null && urlState.has('benchmark')) {
        modelController.setValue(urlState.get('benchmark'));
      }
      // URL values are strings; convert them so that the numeric settings
      // stay numbers, as everywhere else these parameters are read.
      if (numWarmupsController != null && urlState.has('warmup')) {
        numWarmupsController.setValue(Number(urlState.get('warmup')));
      }
      if (numRunsController != null && urlState.has('run')) {
        numRunsController.setValue(Number(urlState.get('run')));
      }
      if (numProfilesController != null && urlState.has('profile')) {
        numProfilesController.setValue(Number(urlState.get('profile')));
      }
      // Only auto run when the URI asks for a task.
      if (correctnessButton != null && task === 'correctness') {
        correctnessButton.$button.click();
      }
      if (runButton != null && task === 'performance') {
        runButton.$button.click();
      }
    }

    /**
     * Runs one prediction and collects its output data, optionally together
     * with information about the intermediate tensors of the graph model.
     *
     * @param {Function} predict Called as `predict(model, inferenceInput)`.
     * @param {*} model Model handed to `predict` and to `getGraphModel`.
     * @param {*} inferenceInput Input handed to `predict`.
     * @param {*} enableDump When truthy, intermediate tensor info is gathered
     *     (only if `getGraphModel(model)` yields a graph model).
     * @returns {Promise<{data: *, intermediateData: *}>} Prediction data and
     *     the intermediate tensor info (an empty object when not gathered).
     */
    async function predictAndGetData(predict, model, inferenceInput, enableDump) {
      const output = await predict(model, inferenceInput);

      const dumpSource = enableDump ? getGraphModel(model) : null;
      let intermediateData = {};
      if (dumpSource) {
        const tensors = dumpSource.getIntermediateTensors();
        intermediateData = await getIntermediateTensorInfo(tensors);
        // The info has been extracted; release the kept tensors.
        dumpSource.disposeIntermediateTensors();
      }

      const data = await getPredictionData(output);
      return {data, intermediateData};
    }

    /**
     * Shared, mutable benchmark configuration. Its properties are bound to
     * GUI controllers and read by the benchmark functions in this file.
     * `run` and `testCorrectness` are the two user-triggerable actions.
     */
    const state = {
      numWarmups: warmupTimes,
      numRuns: runTimes,
      numProfiles: profileTimes,
      parallelCompile: paralllelCompile,
      benchmark: 'MobileNetV3',
      // Starts the performance benchmark. A failure is shown in the modal
      // message and rethrown so that it is also reported by the runtime.
      run: () => {
        runBenchmark().catch(e => {
          showMsg('Error: ' + e.message);
          throw e;
        });
      },
      // Runs the model once on the CPU backend and once on `state.backend`,
      // then reports whether both predictions match.
      testCorrectness: async () => {
        if (!tf.engine().backendNames().includes(state.backend)) {
          return;
        }
        await tf.ready();
        if (state.isFlagChanged) {
          await setEnvFlags(state.flags);
          envDiv.innerHTML = '';
          showVersions();
          await showEnvironment();
          await showGpuInfo();
        }

        let match, actualData, expectedData;
        await cleanUpTable();

        // load model and run inference
        try {
          const expectedBackend = 'cpu';
          // setBackend() is asynchronous: wait until the reference backend is
          // really active before loading the model and predicting with it.
          await tf.setBackend(expectedBackend);
          await loadModelAndRecordTime();
          await showMsg('Testing correctness');
          await showInputs();
          await showCorrectnessTestParameters();

          let inferenceInput;
          await showMsg(`Runing on ${expectedBackend}`);
          if (state.benchmark === 'custom') {
            inferenceInput = generateInputFromDef(
              state.inputs, model instanceof tf.GraphModel);
          }

          let keepIntermediateTensors = false;
          try {
            keepIntermediateTensors = tf.env().getBool('KEEP_INTERMEDIATE_TENSORS');
          } catch (e) {
            console.warn(e.message);
          }

          // Dumping needs both the flag and a benchmark that does not opt out.
          const enableDump = keepIntermediateTensors &&
              (benchmarks[state.benchmark].supportDump !== false);
          const expectedResult = await predictAndGetData(predict, model, inferenceInput, enableDump);
          expectedData = expectedResult['data'];

          await tf.setBackend(state.backend);
          await showMsg(`Runing on ${state.backend}`);

          // willReadFrequently is set to true for CPU backend. So to compare
          // the result, we also set it to true for GPU backends.
          const isGpuBackend =
              state.backend === 'webgl' || state.backend === 'webgpu';
          let savedWillReadFrequently;
          if (isGpuBackend) {
            savedWillReadFrequently = tf.env().getBool('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU');
            tf.env().set('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', true);
          }

          try {
            const actualResult = await predictAndGetData(predict, model, inferenceInput, enableDump);
            actualData = actualResult['data'];
            if (enableDump) {
              const actualIntermediateObject = actualResult['intermediateData'];
              const expectedIntermediateObject = expectedResult['intermediateData'];

              // urlState is null when the page has no URL parameters (and
              // after a completed benchmark run), so guard the lookups.
              const dumpLevel = urlState?.has('dumpLevel') ? Number(urlState.get('dumpLevel')) : 0;
              const dumpLength = urlState?.has('dumpLength') ? Number(urlState.get('dumpLength')) : 1;
              const dumpPrefix = state.benchmark + '_'+ state.architecture + '_' + state.inputType + '_' + state.inputSize;
              const dumpInput = {[state.backend] : actualIntermediateObject, [expectedBackend] : expectedIntermediateObject};
              await dump(model, dumpInput, dumpPrefix, dumpLevel, dumpLength);
            }
          } finally {
            // Restore the user's setting even when inference or dump fails.
            if (isGpuBackend) {
              tf.env().set('CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU', savedWillReadFrequently);
            }
          }
        } catch (e) {
          showMsg('Error: ' + e.message);
          throw e;
        }

        // compare results
        try {
          await showMsg(null);
          expectObjectsClose(actualData, expectedData);
          match = true;
        } catch (e) {
          match = false;
          throw e;
        } finally {
          appendRow(timeTable, `backend`, state.backend);
          appendRow(timeTable, `Prediction matches CPU`, match);
          appendRow(timeTable, '', '');
        }
      },
      backend: 'webgl',
      kernelTiming: 'aggregate',
      inputSize: 0,
      inputType: '',
      architecture: 'small_075',
      modelType: '',
      modelUrl: '',
      // Set when the model selection changes; cleanUpTable() then clears the
      // result tables, and loadModelAndRecordTime() resets it.
      isModelChanged: false,
      flags: {},
      isFlagChanged: false,
      isModelLoaded: false,
      // Replaced by an array of input descriptors in loadModelAndRecordTime().
      inputs: {},
      shouldReloadTfliteModel: false
    };

    // Handles to the page elements that the functions in this script update.
    // Element that shows the modal progress/error message (see showMsg()).
    const modalDiv = document.getElementById('modal-msg');
    // Body of the "Model" table that receives the timing rows.
    const timeTable = document.querySelector('#timings tbody');
    // <pre> element that shows versions, environment features and GPU info.
    const envDiv = document.getElementById('env');
    const kernelsTableHead = document.getElementById('kernels-thead');
    const kernelTable = document.querySelector('#kernels tbody');
    const parameterTable = document.querySelector('#parameters tbody');
    const modelInputsTable = document.querySelector('#modelInputs tbody');

    // Runtime globals assigned later: `model` and `predict` are set by
    // loadModelAndRecordTime(), `chartWidth` by measureAveragePredictTime().
    // The remaining ones are assigned outside the code visible here.
    let model, tfliteModel, worker, tfliteWorkerAPI, predict, chartWidth;

    /**
     * Shows `message` (followed by '...') in the modal element, or hides the
     * modal when `message` is null or undefined. Afterwards it waits for two
     * `tf.nextFrame()` calls before resolving.
     *
     * @param {?string} message Text to display, or null to hide the modal.
     */
    async function showMsg(message) {
      const visible = message != null;
      if (visible) {
        modalDiv.innerHTML = message + '...';
      }
      modalDiv.style.display = visible ? 'block' : 'none';

      for (let frame = 0; frame < 2; frame++) {
        await tf.nextFrame();
      }
    }

    /**
     * Replaces the content of the environment element with the TensorFlow.js
     * core, layers and converter versions, formatted as indented JSON.
     */
    function showVersions() {
      const versions = {
        core: tf.version_core,
        layers: tf.version_layers,
        converter: tf.version_converter
      };
      envDiv.innerHTML = JSON.stringify(versions, null, 2);
    }

    /**
     * Appends the value returned by `getRendererInfo()` to the environment
     * element as a "GPU info" section.
     */
    async function showGpuInfo() {
      const rendererInfo = await getRendererInfo();
      const gpuSection = `<br>{<br>  GPU info: \n  ${rendererInfo}<br>}<br>`;
      envDiv.innerHTML += gpuSection;
    }

    /**
     * Appends the `tf.env().features` entries, sorted by name, to the
     * environment element. A small tensor addition is timed first.
     * NOTE(review): presumably that run exists so the feature flags are
     * evaluated before they are listed - confirm.
     */
    async function showEnvironment() {
      await tf.time(() => tf.add(tf.tensor1d([1]), tf.tensor1d([1])).data());

      const lines = [];
      for (const name of Object.keys(tf.env().features).sort()) {
        lines.push(`  ${name}: ${tf.env().features[name]}`);
      }
      envDiv.innerHTML += `<br/>{<br>${lines.join('<br>')}<br>}`;
    }

    /**
     * Lists the settings of the current performance benchmark in the
     * parameters table. The model URL is listed only for the 'custom'
     * benchmark; input size, input type and architecture are listed only
     * when the selected benchmark defines them.
     */
    function showBenchmarkingParameters() {
      const rows = [
        ['task', 'Performance Benchmark'],
        ['benchmark', state.benchmark],
        ['modelType', state.modelType || 'Unknown'],
        ['convertedBy', model?.artifacts?.convertedBy || 'Unknown'],
      ];
      if (state.benchmark === 'custom') {
        rows.push(['modelUrl', state.modelUrl]);
      }
      rows.push(
          ['numRuns', state.numRuns],
          ['numProfiles', state.numProfiles],
          ['backend', state.backend],
          ['kernelTiming', state.kernelTiming]);

      // [benchmark definition key, row label, value] of the optional rows.
      const optionalRows = [
        ['inputSizes', 'inputSize', state.inputSize],
        ['inputTypes', 'inputType', state.inputType],
        ['architectures', 'architecture', state.architecture],
      ];
      for (const [definitionKey, label, value] of optionalRows) {
        if (isParameterDefined(definitionKey)) {
          rows.push([label, value]);
        }
      }

      for (const [label, value] of rows) {
        appendRow(parameterTable, label, value);
      }
    }

    /**
     * Clears the kernel table and rebuilds its header row. With 'individual'
     * kernel timing the header also gets input/output columns, plus a GPU
     * programs column on the webgl backend.
     */
    async function setupKernelTable() {
      kernelsTableHead.innerText = '';
      kernelTable.innerHTML = '';

      const perKernel = state.kernelTiming === 'individual';
      const headers = [
        '<b>Kernel</b>',
        '<b>Time(ms)</b>',
        ...(perKernel ? ['<b>Inputs</b>', '<b>Output</b>'] : []),
        ...(perKernel && state.backend === 'webgl' ? ['<b>GPUPrograms</b>'] : []),
      ];
      appendRow(kernelsTableHead, ...headers);

      await tf.nextFrame();
    }

    /**
     * Lists the settings of the correctness test in the parameters table
     * (runs and profiles are always reported as 1), then waits one frame.
     */
    async function showCorrectnessTestParameters() {
      const rows = [
        ['task', 'Correctness Test'],
        ['benchmark', state.benchmark],
        ['modelType', state.modelType || 'Unknown'],
      ];
      if (state.benchmark === 'custom') {
        rows.push(['modelUrl', state.modelUrl]);
      }
      rows.push(['numRuns', 1], ['numProfiles', 1], ['backend', state.backend]);

      for (const [label, value] of rows) {
        appendRow(parameterTable, label, value);
      }

      await tf.nextFrame();
    }

    /**
     * Appends one table row to `tbody` with one cell per entry in `cells`.
     *
     * @param {HTMLElement} tbody Element that receives the new row.
     * @param {...*} cells Cell contents: an HTMLElement is attached as a
     *     child node, anything else is assigned to the cell's innerHTML.
     */
    function appendRow(tbody, ...cells) {
      const row = document.createElement('tr');
      for (const content of cells) {
        const cell = document.createElement('td');
        if (content instanceof HTMLElement) {
          cell.appendChild(content);
        } else {
          cell.innerHTML = content;
        }
        row.appendChild(cell);
      }
      tbody.appendChild(row);
    }

    /**
     * Times the first inference of the loaded model and appends the backend
     * and the measured time to the timing table.
     *
     * For the 'custom' benchmark an input is generated from `state.inputs`;
     * it is disposed once the measurement is over (also when it fails), as
     * the other measurement functions in this file do.
     */
    async function measureFirstPredictTime() {
      let predictFn;
      let input;
      if (state.benchmark === 'custom') {
        input = generateInputFromDef(
            state.inputs, model instanceof tf.GraphModel);
        predictFn = isTflite() ?
            () => predict(undefined, input) :
            getPredictFnForModel(model, input);
      } else {
        predictFn = () => predict(model);
      }

      let firstInferenceTime;
      try {
        firstInferenceTime = await timeFirstInference(
            predictFn, state.parallelCompile);
      } finally {
        // Release the generated input tensors so they do not leak.
        if (input != null) {
          tf.dispose(input);
        }
      }
      const tag = '1st inference time' +
          (state.parallelCompile ? ' (parallel compile)' : '');
      appendRow(timeTable, 'backend', state.backend);
      appendRow(timeTable, tag, printTime(firstInferenceTime));
    }

    /**
     * Runs `state.numWarmups` untimed-for-reporting inferences before the
     * real measurement. With zero warmups it only reports that the warmup
     * was skipped (including a '0 ms' row in the timing table).
     */
    async function warmUp() {
      const numWarmups = state.numWarmups;
      if (numWarmups == 0) {
        await showMsg('Skip warming up');
        appendRow(timeTable, 'backend', state.backend);
        appendRow(timeTable, 'Warmup time', '0 ms');
        return;
      }
      await showMsg(`Warming up for ${numWarmups} time(s)`);

      if (state.benchmark !== 'custom') {
        await timeInference(() => predict(model), numWarmups);
      } else {
        const input = generateInputFromDef(state.inputs, model instanceof tf.GraphModel);
        try {
          if (isTflite()) {
            await timeInference(() => predict(undefined, input), numWarmups);
          } else {
            await timeModelInference(model, input, numWarmups);
          }
        } finally {
          // The generated input is only needed for the warmup runs.
          tf.dispose(input);
        }
      }
      await showMsg(null);
    }

    /**
     * Builds the "Inputs" table from `state.inputs`, unless the table already
     * has content.
     *
     * For the 'custom' benchmark, a shape containing -1 is rendered as an
     * editable text field (pre-filled from the `<inputName>Shape` URL
     * parameter when present), and every input gets editable data-range
     * fields.
     *
     * @returns {Promise<boolean|undefined>} True when at least one variable
     *     shape still has to be entered by the user, false otherwise;
     *     undefined when the table was already populated.
     */
    async function showInputs() {
      if (modelInputsTable.innerHTML !== '') {
        // The input table for this model has already been constructed.
        return;
      }
      const modelInputsElement = document.getElementById('modelInputs');
      const inputs = isTflite() ? await tfliteModel.inputs : model.inputs;
      let needInputShape = false;

      if (inputs == null) {
        modelInputsElement.style.display = 'none';
      } else {
        modelInputsElement.style.display = '';
        // Construct the input table for the new model.
        for (const [inputIndex, modelInput] of state.inputs.entries()) {
          if (modelInput.shape == null) {
            continue;
          }
          const shapeName = modelInput.name.concat('Shape');
          const hasVariableDim = modelInput.shape.some(dim => dim == -1);
          let shapeCell = modelInput.shape;

          if (hasVariableDim && state.benchmark === 'custom') {
            const shapeInput = document.createElement("INPUT");
            // Copies the comma-separated text of the field into the state.
            const storeShape = () => {
              state.inputs[inputIndex].shape =
                  shapeInput.value.split(',').map(Number);
            };
            shapeInput.setAttribute('type', 'text');
            if (urlState != null && urlState.has(shapeName)) {
              shapeInput.setAttribute('value', urlState.get(shapeName));
              storeShape();
            } else {
              shapeInput.setAttribute('value', shapeCell);
              needInputShape = true;
            }
            shapeInput.addEventListener('focusout', storeShape);
            shapeCell = shapeInput;
          }

          appendRow(modelInputsTable, 'name', modelInput.name);
          appendRow(modelInputsTable, 'shape', shapeCell);
          appendRow(modelInputsTable, 'data type', modelInput.dtype);

          // Add the data range used by the input generator.
          if (state.benchmark === 'custom') {
            const range = state.inputs[inputIndex].range;
            const minInput = createRangeInput(range, 0);
            const maxInput = createRangeInput(range, 1);
            const dash = document.createElement('span');
            dash.textContent = ' - ';
            const rangeCell = document.createElement('div');
            for (const part of [minInput, dash, maxInput]) {
              rangeCell.appendChild(part);
            }
            appendRow(modelInputsTable, 'data range', rangeCell);
          }
          appendRow(modelInputsTable, '', '');
        }
      }

      await tf.nextFrame();
      return needInputShape;
    }

    /**
     * Creates a text field that edits one entry of `range`.
     *
     * @param {Array<number>} range Array that is updated in place.
     * @param {number} index Position in `range` shown and edited by the field.
     * @returns {HTMLElement} The text field, pre-filled with `range[index]`;
     *     when it loses focus its content is written back as a number.
     */
    function createRangeInput(range, index) {
      const field = document.createElement("INPUT");
      field.setAttribute('type', 'text');
      field.setAttribute('value', range[index]);

      const commit = () => {
        range[index] = Number(field.value);
      };
      field.addEventListener('focusout', commit);
      return field;
    }

    /**
     * Copies the settings present in `urlState` into `state`. Numeric
     * settings are converted with Number(); the others are stored as the raw
     * strings. Does nothing when `urlState` is null.
     */
    function updateStateFromURLState() {
      if (urlState == null) {
        return;
      }
      const asIs = (value) => value;
      // [URL parameter, state property, conversion]
      const bindings = [
        ['run', 'numRuns', Number],
        ['profile', 'numProfiles', Number],
        ['benchmark', 'benchmark', asIs],
        ['modelUrl', 'modelUrl', asIs],
        ['backend', 'backend', asIs],
        ['inputSize', 'inputSize', Number],
        ['inputType', 'inputType', asIs],
        ['architecture', 'architecture', asIs],
      ];
      for (const [urlKey, stateKey, convert] of bindings) {
        if (urlState.has(urlKey)) {
          state[stateKey] = convert(urlState.get(urlKey));
        }
      }
    }

    /**
     * Loads the model of the selected benchmark, rebuilds `state.inputs` from
     * the model's input descriptions, creates the `predict` function and
     * appends the elapsed load time to the timing table.
     *
     * @throws {Error} When the benchmark has no `load` method, or no
     *     `loadTflite` method although `isTflite()` is true.
     */
    async function loadModelAndRecordTime() {
      updateStateFromURLState();
      const benchmark = benchmarks[state.benchmark];
      state.modelType = benchmark['type'];
      state.isModelChanged = false; // used to clean the performance history

      if (benchmark['load'] == null) {
        throw new Error(`Please provide a load method for '${state.benchmark}' model.`);
      }
      if (benchmark['loadTflite'] == null && isTflite()) {
        throw new Error(`Please provide a loadTflite method for '${state.benchmark}' model.`);
      }

      await showMsg('Loading the model');
      const loadStart = performance.now();
      const inputSize = parseFloat(state.inputSize);
      if (!isTflite()) {
        model = await benchmark.load(inputSize, state.architecture, state.inputType);
      } else {
        await loadTfliteModel();
      }

      // Construct the input state for the model.
      state.inputs = [];
      const inputs = isTflite() ? await tfliteModel.inputs : model.inputs;
      if (inputs) {
        for (let i = 0; i < inputs.length; i++) {
          const {name, shape, dtype} = inputs[i];
          if (shape == null) {
            continue;
          }
          state.inputs.push({
            name,
            // Unknown (null/undefined) dimensions are stored as -1.
            shape: shape.map(dim => dim == null ? -1 : dim),
            dtype,
            range: [0, 1000]
          });
        }
      }
      predict = benchmark.predictFunc(inputSize);
      const elapsed = performance.now() - loadStart;

      await showMsg(null);
      appendRow(timeTable, 'Model load', printTime(elapsed));
    }

    // Height in pixels of the trendline chart.
    const chartHeight = 150;

    /**
     * Draws `data` as a polyline in the SVG found inside `node` and labels
     * the vertical range of the chart.
     *
     * @param {Element} node Container holding the `svg`, `path`, `.yMin` and
     *     `.yMax` elements.
     * @param {Array<number>} data Values to plot, in order.
     * @param {boolean=} forceYMinToZero When true the vertical axis starts at
     *     0 instead of at the smallest value.
     * @param {function(number): *=} yFormatter Formats the axis labels.
     */
    function populateTrendline(node, data, forceYMinToZero = false, yFormatter = d => d) {
      const svg = node.querySelector('svg');
      svg.setAttribute('width', chartWidth);
      svg.setAttribute('height', chartHeight);

      const path = node.querySelector('path');
      if (data.length === 0) {
        // Nothing to plot: Math.max()/Math.min() of nothing would produce
        // +/-Infinity labels and a malformed path.
        node.querySelector('.yMin').textContent = '';
        node.querySelector('.yMax').textContent = '';
        path.setAttribute('d', '');
        return;
      }

      const yMax = Math.max(...data);
      let yMin = forceYMinToZero ? 0 : Math.min(...data);
      if (yMin === yMax) {
        yMin = 0;
      }
      // When every value is 0 the range is 0; dividing by it would put NaN
      // into the path. Any non-zero divisor draws the same flat line.
      const yRange = (yMax - yMin) || 1;

      node.querySelector('.yMin').textContent = yFormatter(yMin);
      node.querySelector('.yMax').textContent = yFormatter(yMax);
      // A single point is drawn at x = 0.
      const xIncrement = data.length === 1 ? 0 : chartWidth / (data.length - 1);

      const points = data.map((d, i) =>
          `${i * xIncrement},${chartHeight - ((d - yMin) / yRange) * chartHeight}`);
      path.setAttribute('d', `M${points.join('L')} `);
    }

    /**
     * Times `state.numRuns` inferences, plots the individual times as a
     * trendline and appends the average and best time to the timing table.
     * Returns immediately when the number of runs is 0.
     */
    async function measureAveragePredictTime() {
      if (state.numRuns == 0) {
        return;
      }
      await showMsg(`Running predict ${state.numRuns} times`);
      chartWidth = document.querySelector('#perf-trendline-container').getBoundingClientRect().width;

      const numRuns = state.numRuns;
      let timeInfo;
      if (state.benchmark !== 'custom') {
        timeInfo = await timeInference(() => predict(model), numRuns);
      } else {
        const input = generateInputFromDef(state.inputs, model instanceof tf.GraphModel);
        try {
          if (isTflite()) {
            timeInfo = await timeInference(()=>predict(undefined, input), numRuns);
          } else {
            timeInfo = await timeModelInference(model, input, numRuns);
          }
        } finally {
          // The generated input is only needed for the timed runs.
          tf.dispose(input);
        }
      }

      // The inference trendline always starts its vertical axis at zero.
      const trendlineContainer = document.querySelector('#perf-trendline-container');
      populateTrendline(trendlineContainer, timeInfo.times, true, printTime);

      await showMsg(null);
      appendRow(timeTable, `Subsequent average(${state.numRuns} runs)`, printTime(timeInfo.averageTime));
      appendRow(timeTable, 'Best time', printTime(timeInfo.minTime));
      // Add an empty row at the end of a round of benchmark.
      appendRow(timeTable, '', '');
    }

    /**
     * Profiles `state.numProfiles` inferences and reports peak memory,
     * leaked tensors (both skipped for tflite), the average profile time,
     * the number of kernels and the per-kernel timings.
     * Returns immediately when the number of profiles is 0.
     */
    async function profileMemoryAndKernelTime() {
      if (state.numProfiles == 0) {
        return;
      }
      await showMsg('Profile memory and kernels');

      // Reload tflite model with profiling enabled.
      if (isTflite()) {
        await loadTfliteModel(true);
        // Make sure the model is reloaded (with profiling disabled) the next
        // time "Run benchmark" is clicked. Otherwise the profiling-enabled
        // model would keep being used to measure the total run time, which
        // is not accurate.
        state.shouldReloadTfliteModel = true;
      }

      const profileStart = performance.now();

      const numProfiles = state.numProfiles;
      let profileInfo;
      if (state.benchmark !== 'custom') {
        profileInfo = await profileInference(() => predict(model), isTflite(), numProfiles);
      } else {
        const input = generateInputFromDef(state.inputs, model instanceof tf.GraphModel);
        try {
          profileInfo = isTflite() ?
              await profileInference(() => predict(undefined, input), true, numProfiles) :
              await profileModelInference(model, input, numProfiles);
        } finally {
          // The generated input is only needed for the profiled runs.
          tf.dispose(input);
        }
      }

      const elapsed = performance.now() - profileStart;
      await showMsg(null);
      // TODO: These are not supported by tflite models yet.
      if (!isTflite()) {
        appendRow(timeTable, 'Peak memory', printMemory(profileInfo.peakBytes));
        appendRow(timeTable, 'Leaked tensors', profileInfo.newTensors);
      }
      appendRow(timeTable, 'Profile time', printTime(elapsed / numProfiles));

      const timerMissing = state.backend === 'webgl' && !queryTimerIsEnabled();
      if (timerMissing) {
        const warning = 'Query timer extension is not available. <br/> ' +
            'The kernel result is not accurate since it is not a direct ' +
            'measure of GPU execution time.';
        await showMsg(warning);
      }

      await showMsg('Profiling kernels');
      appendRow(timeTable, 'Number of kernels', profileInfo.kernels.length);
      showKernelTime(profileInfo);
      await showMsg(null);
    }

    /**
     * Fills the kernel table from a profiling result: one row per kernel
     * (with input/output shapes and extra info) for 'individual' kernel
     * timing, otherwise one row per aggregated kernel.
     *
     * @param {Object} profileInfo Profiling result; `kernels` is read for
     *     'individual' timing and `aggregatedKernels` otherwise.
     */
    function showKernelTime(profileInfo) {
      const tbody = document.querySelector('#kernels tbody');

      // Builds the name cell; the title repeats the kernel name.
      const makeNameSpan = (name) => {
        const span = document.createElement('span');
        span.setAttribute('title', name);
        span.textContent = name;
        return span;
      };

      if (state.kernelTiming !== 'individual') {
        for (const aggregated of profileInfo.aggregatedKernels) {
          appendRow(tbody, makeNameSpan(aggregated.name), aggregated.timeMs.toFixed(2));
        }
        return;
      }

      for (const kernel of profileInfo.kernels) {
        const nameSpan = makeNameSpan(kernel.name);
        // TODO: Add the name for each input node in the returned value of
        // `tf.profile` at https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/engine.ts#L657.
        const inputLines = [];
        kernel.inputShapes.forEach((inputShape, index) => {
          inputLines.push(inputShape == null ?
              `input${index}: null` :
              `input${index}: ${inputShape.length}D[${inputShape}]`);
        });
        const inputInfo = inputLines.join('<br>');
        const outputInfo = kernel.outputShapes.length !== 0 ? kernel.outputShapes : '-';
        appendRow(tbody, nameSpan, kernel.kernelTimeMs.toFixed(2),
                  inputInfo || '-', outputInfo, kernel.extraInfo || '-');
      }
    }

    /**
     * Clears the parameters table and, when the model has changed, also the
     * timing, kernel and model-input tables. Then waits one frame.
     */
    async function cleanUpTable() {
      const tablesToClear = state.isModelChanged ?
          [parameterTable, timeTable, kernelTable, modelInputsTable] :
          [parameterTable];
      for (const table of tablesToClear) {
        table.innerHTML = '';
      }
      await tf.nextFrame();
    }

    /**
     * Loads (or reloads) the model when the model or the flags changed, when
     * no model has been loaded yet, or when a tflite model has to be
     * reloaded; then shows the model inputs.
     *
     * @returns {Promise<boolean>} False when the user still has to provide
     *     variable input shapes before the benchmark can run, true otherwise.
     */
    async function loadModel() {
      const needsLoad = state.isModelChanged || !state.isModelLoaded ||
          state.isFlagChanged ||
          (isTflite() && state.shouldReloadTfliteModel);
      if (!needsLoad) {
        return true;
      }

      await loadModelAndRecordTime();
      state.isModelLoaded = true;
      const needInputShape = await showInputs();
      if (needInputShape) {
        // Wait for the prompt so it is on screen before the caller returns
        // and so a failure in showMsg() is not lost as a floating promise.
        await showMsg('Please set the variable input shapes, update -1 to the desired value.');
        return false;
      }
      return true;
    }

    /**
     * Runs one complete performance benchmark: applies changed flags,
     * resets the tables, loads the model and then measures first inference,
     * warmup, average inference and kernel profile. On completion the URL
     * state is dropped.
     *
     * Returns early when the selected backend is not registered (and the
     * model is not tflite), when the flags could not be applied, or when the
     * model still needs user input.
     */
    async function runBenchmark() {
      const backendRegistered =
          tf.engine().backendNames().includes(state.backend);
      if (!(backendRegistered || isTflite())) {
        return;
      }
      await tf.ready();
      if (state.isFlagChanged) {
        const flagsApplied = await setEnvFlags(state.flags);
        // Refresh the environment panel so that it reflects the new flags.
        envDiv.innerHTML = '';
        showVersions();
        await showEnvironment();
        await showGpuInfo();
        if (!flagsApplied) {
          return;
        }
      }
      await cleanUpTable();
      await setupKernelTable();

      const modelReady = await loadModel();
      if (state.isFlagChanged) {
        state.isFlagChanged = false;
      }
      // Return immediately if loading the model failed or needs user input.
      if (!modelReady) {
        return;
      }

      const stages = [
        showBenchmarkingParameters,
        measureFirstPredictTime,
        warmUp,
        measureAveragePredictTime,
        profileMemoryAndKernelTime,
      ];
      for (const stage of stages) {
        await stage();
      }
      urlState = null;
    }

    /**
     * Tells whether the selected benchmark defines the given parameter.
     *
     * @param {string} parameter Key looked up on the benchmark definition,
     *     e.g. 'inputSizes'.
     * @returns {boolean} True when the key holds a non-null value; false
     *     when the model doesn't support this parameter.
     */
    function isParameterDefined(parameter) {
      const definition = benchmarks[state.benchmark];
      return definition[parameter] != null;
    }

    /**
     * Tells whether the page URL carries the given parameter.
     *
     * @param {string} parameter Name of the URL parameter.
     * @returns {boolean} True when `urlState` exists and has the parameter.
     */
    function isURLParameterDefined(parameter) {
      return Boolean(urlState?.has(parameter));
    }

    async function onPageLoad() {
      const gui = new lil.GUI({ width: 350 });
      gui.domElement.id = 'gui';

      urlState = getURLState(location.search);
      let localBuild = [];
      if (urlState && urlState.has('localBuild')) {
        localBuild = urlState.get('localBuild').split(',');
      }

      await loadTFJS(localBuild);

      if (urlState && urlState.has('backend'))
        state.backend = urlState.get('backend');
      await tf.setBackend(state.backend);

      const modelFolder = gui.addFolder('Benchmark');
      let modelUrlController = null;
      modelController = modelFolder.add(state, 'benchmark', Object.keys(benchmarks));
      modelController.name('models').onChange(async model => {
        await showMsg(null);
        const benchmark = benchmarks[state.benchmark];
        state.isModelChanged = true;
        state.isModelLoaded = false;
        if (modelArchitectureController !== null) {
          modelArchitectureController.destroy();
          modelArchitectureController = null;
        }
        if (inputSizeController !== null) {
          inputSizeController.destroy();
          inputSizeController = null;
        }
        if (inputTypeController !== null) {
          inputTypeController.destroy();
          inputTypeController = null;
        }

        if (isParameterDefined('inputSizes')) {
          inputSizeController = modelParameterFolder.add(state, 'inputSize', benchmark['inputSizes']).name('inputSizes').onChange(async inputSize => {
            state.inputSize = inputSize;
            state.isModelChanged = true;
          });
          // Current use first value as default.
          let defaultInputSize = 0;
          if (isURLParameterDefined('inputSize')) {
            defaultInputSize = urlState.get('inputSize');
          } else {
            defaultInputSize = benchmark['inputSizes'][0];
          }
          inputSizeController.setValue(defaultInputSize);
          state.inputSize = defaultInputSize;
        } else {
          // Model doesn't support input size.
          state.inputSize = 0;
        }

        if (isParameterDefined('architectures')) {
          modelArchitectureController = modelParameterFolder.add(state, 'architecture', benchmark['architectures']).name('architectures').onChange(async architecture => {
            state.architecture = architecture;
            state.isModelChanged = true;
          });
          // Currently, the first value is used as the default unless the URL
          // specifies one.
          let defaultModelArchitecture = null;
          if (isURLParameterDefined('architecture'))
            defaultModelArchitecture = urlState.get('architecture');
          else
            defaultModelArchitecture = benchmark['architectures'][0];

          modelArchitectureController.setValue(defaultModelArchitecture);
          state.architecture = defaultModelArchitecture;
        } else {
          // Model doesn't support architecture.
          state.architecture = '';
        }

        if (isParameterDefined('inputTypes')) {
          inputTypeController = modelParameterFolder.add(state, 'inputType', benchmark['inputTypes']).name('inputTypes').onChange(async inputType => {
            state.inputType = inputType;
            state.isModelChanged = true;
          });
          // Currently, the first value is used as the default unless the URL
          // specifies one.
          let defaultInputType = null;
          if (isURLParameterDefined('inputType')) {
            defaultInputType = urlState.get('inputType');
          } else {
            defaultInputType = benchmark['inputTypes'][0];
          }
          inputTypeController.setValue(defaultInputType);
          state.inputType = defaultInputType;
        } else {
          // Model doesn't support input type.
          state.inputType = '';
        }

        // Unfold the model parameter UI if any model parameters are defined
        // in the current model.
        if (isParameterDefined('inputSizes') || isParameterDefined('inputTypes') || isParameterDefined('architectures')) {
          modelParameterFolder.open();
        }

        if (model === 'custom') {
          if (modelUrlController === null) {
            modelUrlController = modelFolder.add(state, 'modelUrl').onChange(async modelUrl => {
              state.isModelChanged = true;
              state.isModelLoaded = false;
              await showMsg(null);
            });
            modelUrlController.domElement.querySelector('input').placeholder =
              'https://your-domain.com/model-path/model.json';
            if (modelUrlController != null && urlState && urlState.has('modelUrl')) {
              modelUrlController.setValue(urlState.get('modelUrl'));
            }
          }
        } else if (modelUrlController != null) {
          modelUrlController.destroy();
          modelUrlController = null;
        }
      });
      modelFolder.open();

      const parameterFolder = gui.addFolder('Parameters');
      numWarmupsController = parameterFolder.add(state, 'numWarmups');
      numRunsController = parameterFolder.add(state, 'numRuns');
      numProfilesController = parameterFolder.add(state, 'numProfiles');
      parameterFolder.add(state, 'kernelTiming', ['aggregate', 'individual']);
      parallelCompileController = parameterFolder.add(state, 'parallelCompile');
      parameterFolder.open();

      // Show model parameter UI when loading the page.
      modelParameterFolder = gui.addFolder('Model Parameters');
      // For each model parameter, show it only if it is defined in the
      // pre-selected model.
      if (isParameterDefined('architectures')) {
        modelArchitectureController = modelParameterFolder.add(state, 'architecture', benchmarks[state.benchmark]['architectures']);
        modelArchitectureController.setValue(state.architecture);
      }
      if (isParameterDefined('inputSizes')) {
        inputSizeController = modelParameterFolder.add(state, 'inputSize', benchmarks[state.benchmark]['inputSizes']);
        inputSizeController.setValue(state.inputSize);
      }
      if (isParameterDefined('inputTypes')) {
        inputTypeController = modelParameterFolder.add(state, 'inputType', benchmarks[state.benchmark]['inputTypes']);
        inputTypeController.setValue(state.inputType);
      }

      // Unfold the model parameter UI if any model parameters are defined
      // in the pre-selected model.
      if (isParameterDefined('inputSizes') || isParameterDefined('inputTypes') || isParameterDefined('architectures')) {
          modelParameterFolder.open();
      }

      const envFolder = gui.addFolder('Environment');
      const backendsController = envFolder.add(
        state, 'backend', ['wasm', 'webgl', 'cpu', 'webgpu', 'tflite']);
      backendsController.onChange(async backend => {
        // TODO: Failed to set backend
        if (state.backend === 'webgpu' && !tf.engine().backendNames().includes(state.backend)) {
          showMsg(
            `WebGPU is not supported. Please use Chrome Canary browser with flag "--enable-unsafe-webgpu" enabled.`);
          return;
        } else {
          showMsg(null);
        }
        if (isTflite()) {
          // Remove the "Test correctness" button which doesn't apply to tfjs-tflite.
          correctnessButton.destroy();
          correctnessButton = null;
          parallelCompileController.destroy();
          parallelCompileController = null;

          // Update models drop down to only show models that support tflite.
          const tfliteBenchmarks = Object.keys(benchmarks)
            .filter(name => benchmarks[name].loadTflite);
          updateModelsDropdown(tfliteBenchmarks);
          modelController.setValue(tfliteBenchmarks[0]);
          state.isModelLoaded = false;
          state.isModelChanged = true;
        } else {
          // Show the "Test correctness" button and re-populate models dropdown when changed from
          // tflite to non-tflite backend.
          if (correctnessButton == null) {
            correctnessButton = gui.add(state, 'testCorrectness');
            parallelCompileController = parameterFolder.add(state, 'parallelCompile');
            const allBenchmarks = Object.keys(benchmarks);
            updateModelsDropdown(allBenchmarks);
            modelController.setValue(allBenchmarks[0]);
            state.isModelLoaded = false;
            state.isModelChanged = true;
          }
          await tf.setBackend(backend);
        }
        await showFlagSettingsAndReturnTunableFlagControllers(envFolder, state.backend);
      });
      envFolder.open();

      runButton = gui.add(state, 'run');
      runButton.name('Run benchmark');
      correctnessButton = gui.add(state, 'testCorrectness');
      correctnessButton.name('Test correctness');

      tunableFlagsControllers = await showFlagSettingsAndReturnTunableFlagControllers(envFolder, state.backend);
      showVersions();
      await showEnvironment();
      await showGpuInfo();
      updateGUIFromURLState();
    }

    /**
     * Tells whether the backend currently stored in `state` is tflite.
     *
     * @returns {boolean} true when `state.backend` is exactly 'tflite'.
     */
    function isTflite() {
      const selectedBackend = state.backend;
      return selectedBackend === 'tflite';
    }

    /**
     * Loads the tflite model of the benchmark selected in `state.benchmark`
     * and stores it in `tfliteModel`.
     *
     * On the first call the tflite worker and its Comlink proxy are created.
     * If a model is already held, its `cleanUp()` is awaited before the new
     * model is requested. `state.shouldReloadTfliteModel` is reset to false
     * before the load starts.
     *
     * @param {boolean} [enableProfiling=false] Forwarded as the first argument
     *     of the benchmark's `loadTflite`.
     * @returns {Promise<void>} Settles once `tfliteModel` has been assigned.
     */
    async function loadTfliteModel(enableProfiling = false) {
      // Create the worker and its proxy only once; later calls reuse them.
      const workerMissing = !worker;
      if (workerMissing) {
        worker = new Worker('./tflite_worker.js');
        tfliteWorkerAPI = Comlink.wrap(worker);
      }

      // Let the previously loaded model (if any) clean up first.
      if (tfliteModel) {
        await tfliteModel.cleanUp();
      }

      state.shouldReloadTfliteModel = false;

      const selectedBenchmark = benchmarks[state.benchmark];
      const loadedModel =
        await selectedBenchmark.loadTflite(enableProfiling, state.architecture);
      tfliteModel = loadedModel;
    }

    /**
     * Replaces the entries offered by the models dropdown in place.
     *
     * The existing `modelController` is updated rather than recreated: its
     * `_names` and `_values` fields are overwritten with `newValues`, and the
     * `<option>` list of the element at `domElement.children[1].children[0]`
     * is regenerated. The selected value is not touched here.
     *
     * NOTE(review): this relies on lil-gui internals (`_names`, `_values` and
     * the controller's DOM layout) — confirm when upgrading lil-gui.
     *
     * @param {string[]} newValues Model names to offer, in display order. The
     *     names are interpolated into HTML unescaped, so they must not contain
     *     HTML-special characters.
     */
    function updateModelsDropdown(newValues) {
      modelController._names = newValues;
      modelController._values = newValues;
      // Join explicitly: assigning the array itself to innerHTML would
      // stringify it with "," separators and leave stray comma text nodes
      // between the <option> elements.
      const optionsHtml = newValues
        .map(name => `<option value="${name}">${name}</option>`)
        .join('');
      modelController.domElement.children[1].children[0].innerHTML = optionsHtml;
    }

    // Last statement of the script: invoke onPageLoad() now that everything
    // above it in this script has been defined.
    onPageLoad();
  </script>
</body>

</html>
